package com.lc.w455;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Tree routine that, for a tree rooted at node 0, counts how many child subtrees
 * fall short of the largest root-to-leaf cost sum found among their siblings.
 *
 * <p>NOTE(review): this presumably answers "minimum number of nodes whose cost must be
 * increased so that every root-to-leaf path has the same total" - confirm against the
 * original problem statement.
 *
 * <p>Instances keep a mutable counter and are therefore not safe for concurrent use.
 */
public class c {

    /** Entry point kept for local experimentation; it intentionally does nothing. */
    public static void main(String[] args) throws IOException {

    }

    /** Running count accumulated by {@link #dfs}; reset at the start of every {@link #minIncrease} call. */
    private int ans = 0;

    /**
     * Builds the undirected adjacency lists, then counts over the tree rooted at node 0 how many
     * children have a best root-to-leaf sum strictly below the best sum among their siblings.
     *
     * @param n     number of nodes; nodes are indexed {@code 0..n-1}
     * @param edges undirected edges, each given as a pair {@code {x, y}}
     * @param cost  per-node cost, indexed by node
     * @return the count described above, computed from scratch on every call
     */
    public int minIncrease(int n, int[][] edges, int[] cost) {
        // Generic array creation cannot be expressed without a raw type; every slot is
        // filled with an ArrayList<Integer> on the next line, so the cast is safe.
        @SuppressWarnings("unchecked")
        List<Integer>[] g = new ArrayList[n];
        Arrays.setAll(g, i -> new ArrayList<>());
        for (int[] e : edges) {
            int x = e[0], y = e[1];
            g[x].add(y);
            g[y].add(x);
        }
        // Sentinel "parent" for the root: dfs detects leaves by degree == 1, and without this
        // extra entry a root with a single child would be mistaken for a leaf.
        g[0].add(-1);

        // Start from zero so that reusing the same instance does not carry over a previous result.
        ans = 0;
        dfs(0, -1, 0, g, cost);
        return ans;
    }

    /**
     * Returns the largest path sum (starting from {@code pathSum}) from {@code x} down to any
     * leaf of its subtree, and adds to the internal counter, for every non-leaf node visited,
     * the number of children whose own largest sum is below the maximum among that node's children.
     *
     * <p>The traversal is iterative (breadth-first order followed by a reverse pass), so the
     * depth of the tree is not limited by the thread stack size.
     *
     * @param x       node at which the traversal starts
     * @param fa      parent of {@code x}; {@code -1} for the root
     * @param pathSum sum of costs on the path above {@code x} (excluding {@code x})
     * @param g       adjacency lists; a node whose list has exactly one entry is treated as a leaf
     * @param cost    per-node cost, indexed by node
     * @return the largest accumulated sum over all leaves reachable from {@code x}
     */
    public long dfs(int x, int fa, long pathSum, List<Integer>[] g, int[] cost) {
        int n = g.length;
        int[] parent = new int[n];
        long[] sum = new long[n];   // cost sum from the start of the path down to each node
        int[] order = new int[n];   // nodes in breadth-first order: parents always precede children

        parent[x] = fa;
        sum[x] = pathSum + cost[x];
        order[0] = x;
        int tail = 1;
        for (int head = 0; head < tail; head++) {
            int u = order[head];
            if (g[u].size() == 1) {
                continue; // leaf: nothing below it to visit
            }
            for (int y : g[u]) {
                if (y == parent[u]) {
                    continue;
                }
                parent[y] = u;
                sum[y] = sum[u] + cost[y];
                order[tail++] = y;
            }
        }

        long[] best = new long[n];  // largest leaf sum seen so far inside each node's subtree
        int[] ties = new int[n];    // number of children that reach exactly best[node]
        // Reverse breadth-first order guarantees every child is finished before its parent.
        for (int i = tail - 1; i >= 0; i--) {
            int u = order[i];
            if (g[u].size() == 1) {
                best[u] = sum[u];
            } else {
                // Every child that does not reach the maximum contributes one to the counter.
                ans += g[u].size() - 1 - ties[u];
            }
            if (i > 0) {
                int p = parent[u];
                if (best[u] > best[p]) {
                    best[p] = best[u];
                    ties[p] = 1;
                } else if (best[u] == best[p]) {
                    ties[p]++;
                }
            }
        }
        return best[x];
    }
}
